from utils_libs import *
from utils_dataset import *
from utils_models import *
from utils_general import *


def train_FedAvg(data_obj, act_prob ,learning_rate, batch_size,
                 epoch, com_amount, test_per, weight_decay,
                 model_func, init_model, sch_step, sch_gamma,
                 rand_seed=0, lr_decay_per_round=1):
    """Run FedAvg-style federated training and periodically evaluate.

    Each communication round samples a non-empty subset of clients: each client
    is active when its uniform draw is ``<= act_prob``, and the draw is repeated
    with a bumped NumPy seed until at least one client is active. Every
    selected client trains a fresh copy of the current broadcast model
    (``avg_model``) with ``train_model``. The next broadcast model is the plain
    mean of the selected clients' flattened parameters.

    Evaluation uses ``server_model``, the mean of the latest parameters of
    *all* clients. Unselected clients contribute their most recent parameters,
    which are ``init_model``'s until they are first selected.

    Args:
        data_obj: Dataset container exposing ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        act_prob: Per-client participation probability per round; must be > 0.
        learning_rate: Base client learning rate. Round ``i`` uses
            ``learning_rate * lr_decay_per_round ** i``.
        batch_size, epoch, weight_decay, sch_step, sch_gamma: Forwarded to
            ``train_model``.
        com_amount: Number of communication rounds.
        test_per: Evaluate every ``test_per`` rounds.
        model_func: Zero-argument factory returning a new model.
        init_model: Model whose parameters initialise every client and the
            broadcast model.
        rand_seed: Offset added to the per-round NumPy seed.
        lr_decay_per_round: Multiplicative per-round learning-rate decay.

    Returns:
        ``(train_perf, test_perf)``: arrays of shape ``(com_amount, 2)`` holding
        ``[loss, accuracy]`` of ``server_model`` on the pooled client training
        data and on the test set. Rows for rounds that were not evaluated stay 0.

    Raises:
        ValueError: If there are no clients or ``act_prob <= 0``. In either
            case the client-sampling loop could never terminate.
    """
    n_client = data_obj.n_client
    if n_client < 1 or act_prob <= 0:
        # Without this guard the "at least one active client" loop spins forever.
        raise ValueError('Need n_client >= 1 and act_prob > 0, got n_client=%s, act_prob=%s'
                         % (n_client, act_prob))

    client_x = data_obj.client_x
    client_y = data_obj.client_y
    # Pooled training data of every client, used to report training metrics.
    cent_x = np.concatenate(client_x, axis=0)
    cent_y = np.concatenate(client_y, axis=0)
    n_par = len(get_mdl_params([model_func()])[0])

    init_par_list = get_mdl_params([init_model], n_par)[0]
    # One row of flattened parameters per client (n_client x n_par), all starting at init_model.
    client_params_list = np.ones(n_client).astype('float32').reshape(-1, 1) * init_par_list.reshape(1, -1)

    train_perf = np.zeros((com_amount, 2))
    test_perf = np.zeros((com_amount, 2))

    client_models = list(range(n_client))
    avg_model = model_func().to(device)
    avg_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    server_model = model_func().to(device)
    server_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    for i in range(com_amount):
        # Resample with an incremented seed until at least one client is active.
        inc_seed = 0
        while True:
            np.random.seed(i + rand_seed + inc_seed)
            act_list = np.random.uniform(size=n_client)
            act_clients = act_list <= act_prob
            selected_clients = np.sort(np.where(act_clients)[0])
            inc_seed += 1
            if len(selected_clients) != 0:
                break

        print('Communication Round', i + 1, flush = True)
        print('Selected Clients: %s' %(', '.join(['%2d' %item for item in selected_clients])))

        lr_round = learning_rate * (lr_decay_per_round ** i)

        # Release last round's client models before building new ones.
        del client_models
        client_models = list(range(n_client))
        for client in selected_clients:
            train_x = client_x[client]
            train_y = client_y[client]
            # No per-client evaluation set is passed to train_model.
            test_x = False
            test_y = False

            # Each selected client starts from the current broadcast model.
            client_models[client] = model_func().to(device)
            client_models[client].load_state_dict(copy.deepcopy(dict(avg_model.named_parameters())))

            for params in client_models[client].parameters():
                params.requires_grad = True
            # NOTE(review): the positional 5 is presumably a print interval — confirm against train_model.
            client_models[client] = train_model(client_models[client], train_x, train_y,
                                                test_x, test_y,
                                                lr_round, batch_size, epoch, 5,
                                                weight_decay,
                                                data_obj.dataset, sch_step, sch_gamma)

            client_params_list[client] = get_mdl_params([client_models[client]], n_par)[0]

        # Broadcast model: mean over this round's participants only. Kept on
        # `device`, matching the other training loops in this file.
        avg_model = set_client_from_params(model_func().to(device), np.mean(client_params_list[selected_clients], axis = 0))
        # Evaluation model: mean over all clients' latest (possibly stale) parameters.
        server_model = set_client_from_params(model_func(), np.mean(client_params_list, axis = 0))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = get_acc_loss(data_obj.test_x, data_obj.test_y,
                                               server_model, data_obj.dataset, 0)
            test_perf[i] = [loss_test, acc_test]

            print("**** Communication all %3d, Test Accuracy: %.4f, Loss: %.4f"
                  %(i+1, acc_test, loss_test), flush = True)

            loss_test, acc_test = get_acc_loss(cent_x, cent_y,
                                               server_model, data_obj.dataset, 0)
            train_perf[i] = [loss_test, acc_test]
            print("**** Communication all %3d, Train Accuracy: %.4f, Loss: %.4f"
                  %(i+1, acc_test, loss_test), flush = True)

        # Freeze the broadcast model; clients only receive copies of its parameters.
        for params in avg_model.parameters():
            params.requires_grad = False

    return train_perf, test_perf


def train_SCAFFOLD(data_obj, act_prob ,learning_rate, batch_size, n_minibatch,
                   com_amount, test_per, weight_decay,
                   model_func, init_model, sch_step, sch_gamma,
                    lr_decay_per_round=1, rand_seed=0, global_learning_rate=1):
    """Run SCAFFOLD-style federated training with control variates.

    ``state_params_diffs`` holds one control variate per client (rows
    ``0..n_client-1``) plus a server ("cloud") control variate in the last row.

    Client weights ``w_i`` are proportional to local sample counts and are
    normalised so that they sum to ``n_client``.

    Each round proceeds as follows:

    1. Sample a non-empty set of clients, as in ``train_FedAvg``.
    2. For each selected client, train from the current broadcast parameters
       ``x`` with ``train_scaffold_mdl``, passing the correction
       ``c / w_i - c_i``.
    3. Update that client's control variate to
       ``c_i - c / w_i + (x - y_i) / (n_minibatch * lr_round)``.
    4. Set the new broadcast parameters to
       ``g * mean(selected y_i) + (1 - g) * x``, where
       ``g = global_learning_rate``.
    5. Move the server variate by ``sum_i w_i * (c_i_new - c_i_old) / n_client``.

    Evaluation uses the mean of all clients' latest parameters.

    Args:
        data_obj: Dataset container exposing ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        act_prob: Per-client participation probability per round; must be > 0.
        learning_rate: Base client learning rate. Round ``i`` uses
            ``learning_rate * lr_decay_per_round ** i``.
        batch_size, n_minibatch, weight_decay, sch_step, sch_gamma: Forwarded
            to ``train_scaffold_mdl``. ``n_minibatch`` also scales the
            control-variate update.
        com_amount: Number of communication rounds.
        test_per: Evaluate every ``test_per`` rounds.
        model_func: Zero-argument factory returning a new model.
        init_model: Model whose parameters initialise all clients and the
            broadcast model.
        lr_decay_per_round: Multiplicative per-round learning-rate decay.
        rand_seed: Offset added to the per-round NumPy seed.
        global_learning_rate: Server step size ``g`` in the update above.

    Returns:
        ``(train_perf, test_perf)``: arrays of shape ``(com_amount, 2)`` with
        ``[loss, accuracy]`` on pooled client training data and on the test
        set. Rows for rounds that were not evaluated stay 0.

    Raises:
        ValueError: If there are no clients or ``act_prob <= 0``. In either
            case the client-sampling loop could never terminate.
    """
    n_client = data_obj.n_client
    if n_client < 1 or act_prob <= 0:
        # Without this guard the "at least one active client" loop spins forever.
        raise ValueError('Need n_client >= 1 and act_prob > 0, got n_client=%s, act_prob=%s'
                         % (n_client, act_prob))
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    # Pooled training data of every client, used to report training metrics.
    cent_x = np.concatenate(client_x, axis=0)
    cent_y = np.concatenate(client_y, axis=0)

    # Weights proportional to local sample counts, normalised to sum to n_client.
    weight_list = np.asarray([len(client_y[i]) for i in range(n_client)])
    weight_list = weight_list / np.sum(weight_list) * n_client
    n_par = len(get_mdl_params([model_func()])[0])
    # Rows 0..n_client-1: per-client control variates; last row: server (cloud) control variate.
    state_params_diffs = np.zeros((n_client + 1, n_par)).astype('float32')
    init_par_list = get_mdl_params([init_model], n_par)[0]
    client_params_list = np.ones(n_client).astype('float32').reshape(-1, 1) * init_par_list.reshape(1, -1)  # n_client X n_par

    client_models = list(range(n_client))
    train_perf = np.zeros((com_amount, 2))
    test_perf = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    server_model = model_func().to(device)
    server_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    for i in range(com_amount):
        inc_seed = 0
        while True:
            np.random.seed(i + rand_seed + inc_seed)
            act_list = np.random.uniform(size=n_client)
            act_clients = act_list <= act_prob
            selected_clients = np.sort(np.where(act_clients)[0])
            inc_seed += 1
            # Choose at least one client in each synch
            if len(selected_clients) != 0:
                break

        print('Communication Round', i + 1, flush = True)
        print('Selected Clients: %s' %(', '.join(['%2d' %item for item in selected_clients])))

        lr_round = learning_rate * (lr_decay_per_round ** i)

        # Release last round's client models before building new ones.
        del client_models

        client_models = list(range(n_client))
        # Weighted sum of this round's control-variate changes.
        delta_c_sum = np.zeros(n_par)
        # Broadcast parameters x at the start of the round.
        prev_params = get_mdl_params([avg_model], n_par)[0]

        for client in selected_clients:
            train_x = client_x[client]
            train_y = client_y[client]

            client_models[client] = model_func().to(device)

            client_models[client].load_state_dict(copy.deepcopy(dict(avg_model.named_parameters())))

            for params in client_models[client].parameters():
                params.requires_grad = True

            # Correction c / w_i - c_i: server variate scaled down by the client weight.
            state_params_diff_curr = torch.tensor(-state_params_diffs[client] + state_params_diffs[-1]/weight_list[client], dtype=torch.float32, device=device)

            # NOTE(review): the positional 5 is presumably a print interval — confirm against train_scaffold_mdl.
            client_models[client] = train_scaffold_mdl(client_models[client], model_func, state_params_diff_curr, train_x, train_y,
                lr_round, batch_size, n_minibatch, 5,
                weight_decay, data_obj.dataset, sch_step, sch_gamma)

            curr_model_param = get_mdl_params([client_models[client]], n_par)[0]
            # c_i_new = c_i - c / w_i + (x - y_i) / (n_minibatch * lr_round); expression kept verbatim.
            new_c = state_params_diffs[client] - state_params_diffs[-1]/weight_list[client] + 1/n_minibatch/learning_rate/(lr_decay_per_round ** i) * (prev_params - curr_model_param)
            # Scale the change back up by the client weight before accumulating it.
            delta_c_sum += (new_c - state_params_diffs[client])*weight_list[client]
            state_params_diffs[client] = new_c

            client_params_list[client] = curr_model_param

        # Server step: g * mean(selected y_i) + (1 - g) * x.
        avg_model_params = global_learning_rate*np.mean(client_params_list[selected_clients], axis = 0) + (1-global_learning_rate)*prev_params

        avg_model = set_client_from_params(model_func().to(device), avg_model_params)

        # Server control variate moves by the weighted average change over all n_client clients.
        state_params_diffs[-1] += 1 / n_client * delta_c_sum

        # Evaluation model: mean over all clients' latest (possibly stale) parameters.
        server_model = set_client_from_params(model_func(), np.mean(client_params_list, axis = 0))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = get_acc_loss(data_obj.test_x, data_obj.test_y,
                                               server_model, data_obj.dataset, 0)
            test_perf[i] = [loss_test, acc_test]

            print("**** Communication all %3d, Test Accuracy: %.4f, Loss: %.4f"
                  %(i+1, acc_test, loss_test), flush = True)

            loss_test, acc_test = get_acc_loss(cent_x, cent_y,
                                               server_model, data_obj.dataset, 0)
            train_perf[i] = [loss_test, acc_test]
            print("**** Communication all %3d, Train Accuracy: %.4f, Loss: %.4f"
                  %(i+1, acc_test, loss_test), flush = True)

        # Freeze the broadcast model; clients only receive copies of its parameters.
        for params in avg_model.parameters():
            params.requires_grad = False

    return train_perf, test_perf


def train_FedSpeed(data_obj, act_prob,
                  learning_rate, batch_size, epoch, com_amount, test_per,
                  weight_decay, model_func, init_model, alpha_coef,
                  sch_step, sch_gamma, rho, rand_seed=0, lr_decay_per_round=1):
    """Run FedSpeed-style federated training with per-client drift correction.

    ``hist_params_diffs[k]`` accumulates, over the rounds in which client ``k``
    was selected, the difference between its trained parameters and the
    broadcast parameters it started from.

    Each round proceeds as follows:

    1. Sample a non-empty set of clients, as in ``train_FedAvg``.
    2. Each selected client warm-starts from ``all_model``. It is trained with
       ``train_model_speed``, which receives the adaptive coefficient
       ``alpha_coef / w_i``, the broadcast parameters and its own accumulated
       difference.
    3. Add ``(y_k - broadcast)`` to ``hist_params_diffs[k]``.
    4. Set the new broadcast parameters to
       ``mean(selected y_k) + mean over all clients of hist_params_diffs``.

    Client weights ``w_i`` are proportional to local sample counts and are
    normalised to sum to ``n_client``. Evaluation uses the mean of all
    clients' latest parameters.

    Args:
        data_obj: Dataset container exposing ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        act_prob: Per-client participation probability per round; must be > 0.
        learning_rate: Base client learning rate. Round ``i`` uses
            ``learning_rate * lr_decay_per_round ** i``.
        batch_size, epoch, weight_decay, sch_step, sch_gamma, rho: Forwarded
            to ``train_model_speed``.
        com_amount: Number of communication rounds.
        test_per: Evaluate every ``test_per`` rounds.
        model_func: Zero-argument factory returning a new model.
        init_model: Model whose parameters initialise all clients and the
            broadcast model.
        alpha_coef: Base coefficient, divided per client by ``w_i``.
        rand_seed: Offset added to the per-round NumPy seed.
        lr_decay_per_round: Multiplicative per-round learning-rate decay.

    Returns:
        ``(train_perf, test_perf)``: arrays of shape ``(com_amount, 2)`` with
        ``[loss, accuracy]`` on pooled client training data and on the test
        set. Rows for rounds that were not evaluated stay 0.

    Raises:
        ValueError: If there are no clients or ``act_prob <= 0``. In either
            case the client-sampling loop could never terminate.
    """
    n_client = data_obj.n_client
    if n_client < 1 or act_prob <= 0:
        # Without this guard the "at least one active client" loop spins forever.
        raise ValueError('Need n_client >= 1 and act_prob > 0, got n_client=%s, act_prob=%s'
                         % (n_client, act_prob))
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    # Pooled training data of every client, used to report training metrics.
    cent_x = np.concatenate(client_x, axis=0)
    cent_y = np.concatenate(client_y, axis=0)

    # Weights proportional to local sample counts, normalised to sum to n_client.
    weight_list = np.asarray([len(client_y[i]) for i in range(n_client)])
    weight_list = weight_list / np.sum(weight_list) * n_client

    train_perf = np.zeros((com_amount, 2))
    test_perf = np.zeros((com_amount, 2))

    n_par = len(get_mdl_params([model_func()])[0])

    # Per-client accumulated (trained - broadcast) parameter differences.
    hist_params_diffs = np.zeros((n_client, n_par)).astype('float32')
    init_par_list = get_mdl_params([init_model], n_par)[0]
    client_params_list = np.ones(n_client).astype('float32').reshape(-1, 1) * init_par_list.reshape(1, -1)  # n_client X n_par
    client_models = list(range(n_client))

    # NOTE(review): avg_model is not used below. It is kept so the sequence of
    # model_func() calls (which may consume the torch RNG) matches earlier runs.
    avg_model = model_func().to(device)
    avg_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    server_model = model_func().to(device)
    server_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))

    # Broadcast model that selected clients warm-start from.
    all_model = model_func().to(device)
    all_model.load_state_dict(copy.deepcopy(dict(init_model.named_parameters())))
    all_model_param = get_mdl_params([all_model], n_par)[0]

    for i in range(com_amount):
        # Resample with an incremented seed until at least one client is active.
        inc_seed = 0
        while True:
            np.random.seed(i + rand_seed + inc_seed)
            act_list = np.random.uniform(size=n_client)
            act_clients = act_list <= act_prob
            selected_clients = np.sort(np.where(act_clients)[0])
            inc_seed += 1
            if len(selected_clients) != 0:
                break

        print('Communication Round', i + 1, flush = True)
        print('Selected Clients: %s' %(', '.join(['%2d' %item for item in selected_clients])))
        all_model_param_tensor = torch.tensor(all_model_param, dtype=torch.float32, device=device)

        lr_round = learning_rate * (lr_decay_per_round ** i)

        # Release last round's client models before building new ones.
        del client_models
        client_models = list(range(n_client))

        for client in selected_clients:
            train_x = client_x[client]
            train_y = client_y[client]

            client_models[client] = model_func().to(device)

            model = client_models[client]
            # Warm start from the current broadcast model.
            model.load_state_dict(copy.deepcopy(dict(all_model.named_parameters())))
            for params in model.parameters():
                params.requires_grad = True
            # Adaptive coefficient: scaled down for clients with more data.
            alpha_coef_adpt = alpha_coef / weight_list[client]
            hist_params_diffs_curr = torch.tensor(hist_params_diffs[client], dtype=torch.float32, device=device)
            # NOTE(review): the positional 5 is presumably a print interval — confirm against train_model_speed.
            client_models[client] = train_model_speed(model, model_func, alpha_coef_adpt,
                                                      all_model_param_tensor, hist_params_diffs_curr,
                                                      train_x, train_y, lr_round,
                                                      batch_size, epoch, 5, weight_decay,
                                                      data_obj.dataset, sch_step, sch_gamma, rho, print_verbose=False)
            curr_model_par = get_mdl_params([client_models[client]], n_par)[0]
            hist_params_diffs[client] += curr_model_par - all_model_param
            client_params_list[client] = curr_model_par

        # New broadcast = mean of this round's participants + mean accumulated difference over all clients.
        avg_mdl_param_sel = np.mean(client_params_list[selected_clients], axis = 0)
        all_model_param = avg_mdl_param_sel + np.mean(hist_params_diffs, axis=0)
        all_model = set_client_from_params(model_func().to(device), all_model_param)
        # Evaluation model: mean over all clients' latest (possibly stale) parameters.
        server_model = set_client_from_params(model_func(), np.mean(client_params_list, axis = 0))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = get_acc_loss(data_obj.test_x, data_obj.test_y,
                                               server_model, data_obj.dataset, 0)
            print("**** Cur All Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                  %(i+1, acc_test, loss_test), flush = True)
            test_perf[i] = [loss_test, acc_test]

            loss_test, acc_test = get_acc_loss(cent_x, cent_y,
                                               server_model, data_obj.dataset, 0)
            print("**** Cur All Communication %3d, Train Accuracy: %.4f, Loss: %.4f"
                  %(i+1, acc_test, loss_test), flush = True)

            train_perf[i] = [loss_test, acc_test]

        # Freeze the broadcast and evaluation models; neither is trained directly.
        for params in all_model.parameters():
            params.requires_grad = False
        for params in server_model.parameters():
            params.requires_grad = False

    return train_perf, test_perf

